#core/formalization/rl/action_mask.py
from typing import List
import numpy as np

from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from core.formalization.action_space import FormalizationAction

class ActionMask:
    """Build binary availability masks over a fixed list of formalization actions.

    Entry ``i`` of a mask corresponds to ``actions[i]`` and is ``1.0`` when that
    action reports it should apply to the current text, ``0.0`` otherwise.
    """

    def __init__(self, logger: "Logger", llm: "LLMWrapper", actions: "List[FormalizationAction]"):
        """Store collaborators for mask computation.

        Args:
            logger: Receives exceptions raised by individual actions' ``should_apply``.
            llm: Stored on the instance; not read by this class's own methods.
            actions: Ordered actions; their order defines the mask's index layout.
        """
        self.llm = llm
        self.logger = logger
        self.actions = actions

    def compute_action_mask(self, current_text: str, context=None) -> np.ndarray:
        """Return a float32 mask with one entry per action.

        Args:
            current_text: Text passed to each action's ``should_apply``.
            context: Optional extra argument forwarded unchanged to ``should_apply``.

        Returns:
            A 1-D ``np.float32`` array of length ``len(self.actions)`` containing
            ``1.0`` where ``should_apply`` returned a truthy value and ``0.0`` elsewhere.

        An action whose ``should_apply`` raises is logged via
        ``logger.log_exception`` and treated as unavailable (best-effort: one
        faulty action must not prevent the mask from being built).
        """
        mask = np.zeros(len(self.actions), dtype=np.float32)
        for i, action in enumerate(self.actions):
            try:
                should_apply = action.should_apply(current_text, context)
                mask[i] = 1.0 if should_apply else 0.0
            except Exception as e:
                # Entry is already 0.0 from np.zeros, so only logging is needed.
                self.logger.log_exception(e)

        return mask

    def has_available_actions(self, mask: np.ndarray) -> bool:
        """Return ``True`` if at least one mask entry is strictly positive.

        Returns a plain Python ``bool`` (not ``numpy.bool_``) so identity checks
        such as ``result is True`` behave as callers expect. An empty mask yields
        ``False``.
        """
        # np.any avoids relying on the sum, which could cancel out for masks
        # containing mixed-sign values.
        return bool(np.any(np.asarray(mask) > 0))